import gfootball.env as football_env
from gfootball.env import observation_preprocessing
import gym
import numpy as np
import matplotlib.pyplot as plt


class GoogleFootballMultiAgentEnv(object):
    """A wrapper for GFootball to make it compatible with our codebase.

    The wrapper creates the 'academy_3_vs_1_with_keeper' scenario, controls
    ``nagents`` left-team players and replaces the stock observations with
    hand-built feature vectors:

    * ``get_obs(i)`` -> per-agent observation (``obs_dim`` = 60 entries).
    * ``get_obs(-1)`` / ``get_global_state()`` -> global state (63 entries).

    Throughout this class the point ``(1, 0)`` is called the "gate"
    (presumably the centre of the opponent goal -- confirm against the
    GFootball coordinate system).
    """

    # Number of right-team players assumed by the dimension bookkeeping. It is
    # the value that makes the per-agent layout add up to obs_dim = 60
    # (16 ego + 16 ball + 7 * 2 teammates + 7 * 2 opponents).
    # NOTE(review): confirm against the scenario definition.
    _NUM_RIGHT_PLAYERS = 2

    def __init__(self, dense_reward, dump_freq, render=False,
                 logdir='/home/liuboyin/MARL/Diversity/video'):
        """Create the underlying GFootball environment.

        Args:
            dense_reward: Stored on the instance; not read anywhere else in
                this class.
            dump_freq: Forwarded to GFootball as ``dump_frequency``.
            render: Forwarded to GFootball as ``render``.
            logdir: Forwarded to GFootball as ``logdir``. The default is the
                path that used to be hard-coded here; pass your own directory
                on other machines.
        """
        self.nagents = 3  # only the 'academy_3_vs_1_with_keeper' level is used for the moment
        self.time_limit = 150  # episode is cut off after this many steps
        self.time_step = 0
        self.obs_dim = 60  # per-agent observation length for 3 vs 1
        self.predict_dim = 12
        self.dense_reward = dense_reward  # stored only; step() does not use it
        self.leader_id = 0
        self.leader_obs = 0
        self.reward_times = 3  # not read anywhere in this class

        self.env = football_env.create_environment(
            env_name='academy_3_vs_1_with_keeper',
            stacked=False,
            representation="simple115",
            rewards='scoring',
            logdir=logdir,
            render=render,
            write_video=True,
            write_goal_dumps=True,
            dump_frequency=dump_freq,
            number_of_left_players_agent_controls=self.nagents,
            number_of_right_players_agent_controls=0,
            channel_dimensions=(observation_preprocessing.SMM_WIDTH, observation_preprocessing.SMM_HEIGHT))

        # Bounds are borrowed from the first obs_dim entries of the wrapped
        # env's space; the custom features are not guaranteed to respect them.
        obs_space_low = self.env.observation_space.low[0][:self.obs_dim]
        obs_space_high = self.env.observation_space.high[0][:self.obs_dim]

        self.action_space = [gym.spaces.Discrete(
            self.env.action_space.nvec[1]) for _ in range(self.nagents)]
        self.observation_space = [
            gym.spaces.Box(low=obs_space_low, high=obs_space_high, dtype=self.env.observation_space.dtype) for _ in
            range(self.nagents)
        ]

    def get_leader_obs(self, full_obs):
        """Pick the controlled player closest to the ball and cache its features.

        Sets ``self.leader_id`` (index among the last ``nagents`` left-team
        players, i.e. the same indexing as ``get_obs``) and ``self.leader_obs``
        (a list of 7 arrays holding 11 numbers in total). Returns nothing.

        Args:
            full_obs: Raw observation dict; reads 'ball', 'left_team' and
                'left_team_direction' (NumPy arrays).
        """
        ball_xy = full_obs['ball'][:2].reshape(1, 2)
        positions = full_obs['left_team'][-self.nagents:]
        directions = full_obs['left_team_direction'][-self.nagents:]

        # Offset is player minus ball; only its norm decides who leads.
        offsets = positions - ball_xy
        distances = np.linalg.norm(offsets, axis=-1)
        leader_id = np.argmin(distances)
        self.leader_id = leader_id

        to_gate = np.array([1, 0]) - positions[leader_id]
        self.leader_obs = [
            positions[leader_id],                            # leader position (2)
            directions[leader_id],                           # leader direction / velocity (2)
            offsets[leader_id],                              # leader position minus ball position (2)
            distances[leader_id][np.newaxis],                # leader-ball distance (1)
            to_gate,                                         # vector from leader to the gate (2)
            np.arctan2(to_gate[1], to_gate[0])[np.newaxis],  # angle of that vector (1)
            np.linalg.norm(to_gate)[np.newaxis],             # leader-gate distance (1)
        ]  # 11 numbers in total

    def get_obs(self, index=-1):
        """Build the global state (``index == -1``) or one agent's observation.

        The leader is refreshed from the current raw observation on every call
        so that ``leader_id`` / ``leader_label`` in a per-agent observation
        never depend on which index was queried previously.

        Args:
            index: ``-1`` for the global state, otherwise the agent index in
                ``range(nagents)``.

        Returns:
            1-D NumPy array; 63 entries for the global state and 60 entries for
            a per-agent observation in the 3 vs 1 scenario.
        """
        full_obs = self.env.unwrapped.observation()[0]
        self.get_leader_obs(full_obs)
        if index == -1:
            return self._global_state(full_obs)
        return self._agent_obs(full_obs, index)

    def _ball_flags(self, full_obs):
        """Return (zone one-hot, ownership one-hot) for the ball."""
        ball_x, ball_y = full_obs['ball'][:2]
        zone = self._encode_ball_which_zone(ball_x, ball_y)
        owner = full_obs['ball_owned_team']
        # Team id 0 -> slot 0, team id 1 -> slot 1, anything else -> slot 2.
        if owner == 0:
            owned = [1, 0, 0]
        elif owner == 1:
            owned = [0, 1, 0]
        else:
            owned = [0, 0, 1]
        return zone, owned

    @staticmethod
    def _team_block(positions, directions, tired, reference):
        """Flatten per-player rows of (position, direction, speed, distance to reference, tired)."""
        positions = np.asarray(positions)
        directions = np.asarray(directions)
        speed = np.linalg.norm(directions, axis=1, keepdims=True)
        distance = np.linalg.norm(positions - reference, axis=1, keepdims=True)
        tired = np.asarray(tired).reshape(-1, 1)
        return np.concatenate(
            (positions, directions, speed, distance, tired), axis=1).reshape(-1)

    def _global_state(self, full_obs):
        """Global state with absolute positions; distances are measured to the gate."""
        zone, owned = self._ball_flags(full_obs)
        gate = np.array([1, 0])
        controlled = full_obs['left_team'][-self.nagents:]

        # Ball position minus each controlled player's position.
        ball_relative = full_obs['ball'][:2] - controlled
        ball_state = np.concatenate((
            np.array(full_obs['ball']),                            # 3
            np.array(zone),                                        # 6
            np.array(ball_relative).reshape(-1),                   # 2 * nagents
            np.array(full_obs['ball_direction']),                  # 3
            np.array(owned),                                       # 3
            np.linalg.norm(ball_relative, axis=-1),                # nagents
            np.arctan2(ball_relative[:, 1], ball_relative[:, 0]),  # nagents
            np.array(np.linalg.norm(full_obs['ball_direction'][:2])).reshape(-1),  # 1
        ))  # 16 + 4 * nagents = 28

        left_team_state = self._team_block(
            controlled,
            full_obs['left_team_direction'][-self.nagents:],
            full_obs['left_team_tired_factor'][-self.nagents:],
            gate)  # 7 * nagents = 21

        right_team_state = self._team_block(
            full_obs['right_team'],
            full_obs['right_team_direction'],
            full_obs['right_team_tired_factor'],
            gate)  # 7 per opponent = 14

        return np.concatenate((ball_state, left_team_state, right_team_state))

    def _agent_obs(self, full_obs, index):
        """Per-agent observation; other players' distances are measured to the ego player."""
        zone, owned = self._ball_flags(full_obs)
        norm = np.linalg.norm

        # Negative index into left_team so the controlled players are the last nagents rows.
        player_num = -self.nagents + index
        ego_position = full_obs['left_team'][player_num].reshape(-1)
        ego_direction = full_obs['left_team_direction'][player_num]
        ego_tired = full_obs['left_team_tired_factor'][player_num]
        ego_gate = np.array([1, 0]) - ego_position
        ego_ball = full_obs['ball'][:2] - ego_position

        leader_label = 1 if self.leader_id == index else 0

        player_state = np.concatenate((
            [self.leader_id, leader_label],
            ego_position,
            ego_direction,
            ego_gate,
            ego_ball,
            [norm(ego_ball), np.arctan2(ego_ball[1], ego_ball[0]),
             norm(ego_gate), np.arctan2(ego_gate[1], ego_gate[0]),
             norm(ego_direction), ego_tired],
        ))  # 16

        ball_state = np.concatenate((
            np.array(full_obs['ball']),
            np.array(zone),
            np.array(full_obs['ball_direction']),
            np.array(owned),
            norm(full_obs['ball_direction'][:2]).reshape(-1),
        ))  # 16

        # Teammates: controlled players with the ego row removed.
        left_team_state = self._team_block(
            np.delete(full_obs['left_team'][-self.nagents:], index, axis=0),
            np.delete(full_obs['left_team_direction'][-self.nagents:], index, axis=0),
            np.delete(full_obs['left_team_tired_factor'][-self.nagents:], index, axis=0),
            ego_position)  # 7 * (nagents - 1) = 14

        right_team_state = self._team_block(
            full_obs['right_team'],
            full_obs['right_team_direction'],
            full_obs['right_team_tired_factor'],
            ego_position)  # 7 per opponent = 14

        return np.concatenate((player_state, ball_state, left_team_state, right_team_state))

    def get_global_state(self):
        """Return the global state, i.e. ``get_obs(-1)``."""
        return self.get_obs(-1)

    def reset(self):
        """Reset the wrapped env and return ``(per-agent observations, global state)``."""
        self.time_step = 0
        self.env.reset()
        obs = np.array([self.get_obs(i) for i in range(self.nagents)])

        return obs, self.get_global_state()

    def _encode_ball_which_zone(self, ball_x, ball_y):
        """One-hot encode which of six pitch zones contains the ball.

        Zones are tested in order: left box, rest of left side, middle band,
        right box, rest of right side, and finally "anywhere else".
        """
        MIDDLE_X, PENALTY_X, END_X = 0.2, 0.64, 1.0
        PENALTY_Y, END_Y = 0.27, 0.42
        in_penalty_y = -PENALTY_Y < ball_y < PENALTY_Y
        in_field_y = -END_Y < ball_y < END_Y

        if -END_X <= ball_x < -PENALTY_X and in_penalty_y:
            return [1.0, 0, 0, 0, 0, 0]
        if -END_X <= ball_x < -MIDDLE_X and in_field_y:
            return [0, 1.0, 0, 0, 0, 0]
        if -MIDDLE_X <= ball_x <= MIDDLE_X and in_field_y:
            return [0, 0, 1.0, 0, 0, 0]
        if PENALTY_X < ball_x <= END_X and in_penalty_y:
            return [0, 0, 0, 1.0, 0, 0]
        if MIDDLE_X < ball_x <= END_X and in_field_y:
            return [0, 0, 0, 0, 1.0, 0]
        return [0, 0, 0, 0, 0, 1.0]

    def check_if_done(self):
        """Return True when the ball or any controlled player has x < 0."""
        cur_obs = self.env.unwrapped.observation()[0]
        ball_loc = cur_obs['ball']
        ours_loc = cur_obs['left_team'][-self.nagents:]

        if ball_loc[0] < 0 or any(ours_loc[:, 0] < 0):
            return True

        return False

    def step(self, actions):
        """Advance one step.

        Returns:
            ``(obs, global_state, reward, done, infos)``. ``reward`` is 60 when
            the summed per-agent env rewards are positive and 0 otherwise.
            ``done`` is additionally forced to True at ``time_limit`` or when
            ``check_if_done()`` fires.
        """
        self.time_step += 1
        _, original_rewards, done, infos = self.env.step(actions)
        rewards = list(original_rewards)
        obs = np.array([self.get_obs(i) for i in range(self.nagents)])

        if self.time_step >= self.time_limit:
            done = True

        if self.check_if_done():
            done = True

        reward = 0 if sum(rewards) <= 0 else 60
        return obs, self.get_global_state(), reward, done, infos

    def seed(self, seed):
        """Forward the seed to the wrapped environment."""
        self.env.seed(seed)

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def get_env_info(self):
        """Return the sizes the training code needs to build its networks."""
        # Global state length, derived from the layout built in get_obs(-1):
        # ball block (16 + 4 * nagents) + 7 per controlled player + 7 per
        # opponent = 63 for 3 vs 1. The previously hard-coded 62 was one short
        # of what get_global_state() actually returns.
        state_dim = 16 + 11 * self.nagents + 7 * self._NUM_RIGHT_PLAYERS

        return {
            'n_actions': self.action_space[0].n,
            'obs_shape': self.obs_dim,
            'n_agents': self.nagents,
            'state_shape': state_dim,
            'episode_limit': self.time_limit,
            'predict_shape': self.predict_dim,
        }


# def make_football_env(seed_dir, dump_freq=1000, representation='extracted', render=False):
def make_football_env_3vs1(dense_reward=False, dump_freq=1000, render=False):
    """Build and return a ``GoogleFootballMultiAgentEnv``.

    The returned object is used much like a gym environment, through
    ``env.reset()`` and ``env.step()``.

    Some useful env properties:
        .observation_space  :   observation space for each agent
        .action_space       :   action space for each agent
        .nagents            :   number of agents
    """
    env = GoogleFootballMultiAgentEnv(
        dense_reward=dense_reward,
        dump_freq=dump_freq,
        render=render,
    )
    return env
